#pragma once
#include <flashinfer/attention/hopper/attention_updater.cuh>
#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/cutlass_utils.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>


// Text spliced in when this template is rendered; the macro expands to whatever
// the renderer substitutes for the `additional_func_params` placeholder.
// NOTE(review): presumably a fragment of a function parameter list for the
// variant-specific arguments -- confirm against the code that uses this macro.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
// Text spliced in when this template is rendered; the macro expands to whatever
// the renderer substitutes for the `additional_params_setter` placeholder.
// NOTE(review): presumably statements that populate Params::additional_params
// -- confirm against the code that uses this macro.
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

// Dispatch helper: forwards to DISPATCH_MASK_MODE (defined outside this file)
// and, inside the dispatched block, aliases `AttentionVariant` to the variant
// type chosen at template-render time before invoking the caller's body.
//
// Of the macro parameters, only MASK_MODE, AttentionVariant and the variadic
// body appear in the expansion; the remaining names (DTypeQ, DTypeKV, DTypeO,
// IdType, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP,
// Params) are accepted but not referenced by the expansion itself.
//
// Call-site requirements visible from the expansion:
//   * a variable named `mask_mode` (lowercase) must be in scope -- it is NOT a
//     macro parameter;
//   * the variadic argument is invoked as `__VA_ARGS__()`, so it must be a
//     callable expression (e.g. a lambda).
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__(); })

// Bring the flashinfer namespace into scope so the declarations below can use
// its names unqualified. Note this directive is at file scope of a header, so
// it also affects every translation unit that includes this file.
using namespace flashinfer;

// Element types for the query, key/value and output tensors. The inner type
// names are filled in when the template is rendered and are then mapped through
// cutlass_dtype_t (declared outside this file).
using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>;
using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>;
using DTypeO = cutlass_dtype_t<{{ dtype_o }}>;
// Index type; fixed to the cutlass_dtype_t mapping of int32_t in this template.
using IdType = cutlass_dtype_t<int32_t>;

// Compile-time configuration; every value is substituted when the template is
// rendered.
// Head dimensions, declared separately for the QK and VO sides.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
// Feature switches. Declared `auto`, so the type is deduced from the rendered
// literal. NOTE(review): presumably `true`/`false` -- confirm with the renderer.
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};

struct Params {
  using DTypeQ = DTypeQ;
  using DTypeKV = DTypeKV;
  using DTypeO = DTypeO;
  using IdType = IdType;

  // The QKV matrices.
  DTypeQ* q_ptr;
  DTypeKV* k_ptr;
  DTypeKV* v_ptr;
  DTypeO* o_ptr;
  float* lse_ptr;

  // Additional params
  struct AdditionalParams {
    {{ additional_params_decl }};
  } additional_params;

  int64_t q_stride_n;
  int64_t k_stride_n;
  int64_t v_stride_n;
  int64_t o_stride_n;
  int64_t q_stride_h;
  int64_t k_stride_h;
  int64_t v_stride_h;
  int64_t o_stride_h;

  int qo_len;
  int kv_len;
  int head_dim;
  int num_qo_heads;
  int num_kv_heads;
  int group_size;
  int window_left;

  bool causal;
};

{{ variant_decl }}
